Kth Smallest Element in a BST

Time: O(max(H, k)); Space: O(H), where H is the height of the tree; Difficulty: medium

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note:

  • You may assume k is always valid: 1 ≤ k ≤ the number of nodes in the BST.

Example 1:

Input: root = {TreeNode} [3,1,4,null,2], k = 1

  3
 / \
1   4
 \
  2

Output: 1
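The `[3,1,4,null,2]` notation lists nodes in level order, with `null` marking a missing child. A minimal sketch of how such a list maps to a tree, assuming a hypothetical helper `build_tree` (not part of the problem's API) and using Python's `None` in place of `null`:

```python
from collections import deque


class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_tree(values):
    """Hypothetical helper: build a tree from a level-order list.

    `None` marks a missing child; trailing Nones may be omitted.
    """
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Left child slot.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # Right child slot.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([3, 1, 4, None, 2])
print(root.val, root.left.val, root.right.val, root.left.right.val)  # 3 1 4 2
```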

Example 2:

Input: root = {TreeNode} [5,3,6,2,4,null,null,1], k = 3

      5
     / \
    3   6
   / \
  2   4
 /
1

Output: 3
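An in-order traversal (left subtree, node, right subtree) of a BST visits values in ascending order; for Example 2 that order is 1, 2, 3, 4, 5, 6, so the 3rd smallest is 3. A brute-force baseline that relies only on this property might look like the sketch below (`inorder` and `kth_smallest_bruteforce` are illustrative names). It costs O(n) time and space because it visits every node, whereas the solutions below stop as soon as the kth node is reached.

```python
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def inorder(node, out):
    """Append the subtree's values to `out` in in-order (left, node, right)."""
    if node:
        inorder(node.left, out)
        out.append(node.val)
        inorder(node.right, out)
    return out


def kth_smallest_bruteforce(root, k):
    # In-order traversal of a BST is sorted, so the kth smallest is at index k - 1.
    return inorder(root, [])[k - 1]


# Example 2, built by hand.
root = TreeNode(5)
root.left = TreeNode(3)
root.right = TreeNode(6)
root.left.left = TreeNode(2)
root.left.right = TreeNode(4)
root.left.left.left = TreeNode(1)

print(inorder(root, []))                 # [1, 2, 3, 4, 5, 6]
print(kth_smallest_bruteforce(root, 3))  # 3
```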

Follow up:

  • What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently?

  • How would you optimize the kthSmallest routine?

Hints:

  1. Try to utilize the property of a BST.

  2. Try in-order traversal. (Credits to @chan13)

  3. What if you could modify the BST node’s structure?

  4. The optimal runtime complexity is O(height of BST).

[1]:
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None
[2]:
class Solution1(object):
    """
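    Time: O(max(H, k)), where H is the tree height
    Space: O(H)
    """
    def kthSmallest(self, root, k):
        """
        :type root: TreeNode
        :type k: int
        :rtype: int
        """
        # Iterative in-order traversal: push the left spine, then pop
        # nodes in ascending order, counting them until the kth.
        s, cur, rank = [], root, 0

        while s or cur:
            if cur:
                # Descend left, remembering the path on the stack.
                s.append(cur)
                cur = cur.left
            else:
                # The popped node holds the next smallest value.
                cur = s.pop()
                rank += 1
                if rank == k:
                    return cur.val
                cur = cur.right

        # Only reached if k exceeds the number of nodes.
        return float("-inf")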
[3]:
s = Solution1()

root = TreeNode(3)
root.left = TreeNode(1)
root.right = TreeNode(4)
root.left.right = TreeNode(2)
k = 1
assert s.kthSmallest(root, k) == 1

root = TreeNode(5)
root.left = TreeNode(3)
root.right = TreeNode(6)
root.left.left = TreeNode(2)
root.left.right = TreeNode(4)
root.left.left.left = TreeNode(1)
k = 3
assert s.kthSmallest(root, k) == 3
[4]:
from itertools import islice

class Solution2(object):
    """
    Time: O(max(h, k))
    Space: O(h)
    """
    def kthSmallest(self, root, k):
        """
        :type root: TreeNode
        :type k: int
        :rtype: int
        """
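        def gen_inorder(node):
            # Lazily yield values in ascending (in-order) order.
            if node:
                yield from gen_inorder(node.left)
                yield node.val
                yield from gen_inorder(node.right)

        # islice skips the first k - 1 values; next() takes the kth.
        return next(islice(gen_inorder(root), k - 1, k))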
[5]:
s = Solution2()

root = TreeNode(3)
root.left = TreeNode(1)
root.right = TreeNode(4)
root.left.right = TreeNode(2)
k = 1
assert s.kthSmallest(root, k) == 1

root = TreeNode(5)
root.left = TreeNode(3)
root.right = TreeNode(6)
root.left.left = TreeNode(2)
root.left.right = TreeNode(4)
root.left.left.left = TreeNode(1)
k = 3
assert s.kthSmallest(root, k) == 3